package org.zjvis.datascience.web.controller;

import java.io.IOException;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import javax.servlet.ServletContext;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import lombok.SneakyThrows;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import org.zjvis.datascience.common.annotation.ProjectAuth;
import org.zjvis.datascience.common.constant.Constant;
import org.zjvis.datascience.common.constant.DatabaseConstant;
import org.zjvis.datascience.common.constant.DatasetConstant;
import org.zjvis.datascience.common.dto.*;
import org.zjvis.datascience.common.enums.ModelStatusEnum;
import org.zjvis.datascience.common.enums.ProjectAuthEnum;
import org.zjvis.datascience.common.exception.BaseErrorCode;
import org.zjvis.datascience.common.exception.DataScienceException;
import org.zjvis.datascience.common.model.ApiResult;
import org.zjvis.datascience.common.model.ApiResultCode;
import org.zjvis.datascience.common.model.Table;
import org.zjvis.datascience.common.util.JwtUtil;
import org.zjvis.datascience.common.util.RestTemplateUtil;
import org.zjvis.datascience.common.util.SqlUtil;
import org.zjvis.datascience.common.vo.*;
import org.zjvis.datascience.service.*;
import org.zjvis.datascience.service.dataprovider.GPDataProvider;
import org.zjvis.datascience.service.dataset.DatasetService;
import org.zjvis.datascience.service.mapper.DatasetProjectMapper;

import static org.zjvis.datascience.common.constant.DatabaseConstant.DEFAULT_DATASET_ID;

/**
 * @description 机器学习模型训练接口 Controller
 * @date 2021-11-29
 */
@RequestMapping("/model")
@RestController
@Api(tags = "model", description = "机器学习模型训练接口")
@Validated
public class ModelController {
    /** Logger for this controller, registered under the literal name "ModelController". */
    private final static Logger logger = LoggerFactory.getLogger("ModelController");

    /** Task lookups and dataset saving; used by userTableCheck and saveOutputToDataSource. */
    @Autowired
    private TaskService taskService;

    /** MinIO wrapper; used by delete() to remove trained model artifacts. */
    @Autowired
    private MinioService minioService;

    /** Core model persistence / training service used by almost every endpoint. */
    @Autowired
    private MLModelService mlModelService;

    /** Folder management (create/delete folders, remove models from folders). */
    @Autowired
    private FolderService folderService;

    /** Dataset metadata lookups; used by queryMetricsById to report source size/shape. */
    @Autowired
    private DatasetService datasetService;

    /** Executes queries against the GP database (model output preview, column types). */
    @Autowired
    private GPDataProvider gpDataProvider;

    /** Resolves the datasets attached to a project (used to map source tables to datasets). */
    @Autowired
    private DatasetProjectMapper datasetProjectMapper;

    /** HTTP helper used by the flaskTest endpoint to submit Flask jobs. */
    @Autowired
    private RestTemplateUtil restTemplateUtil;

    // Disabled Dubbo client, kept for reference.
//    @DubboReference
//    private GPDataService gpDataService;

    /** Source of the default data source key attribute (see queryOutputById). */
    @Autowired
    private ServletContext servletContext;

    // NOTE(review): the three constants below are public, non-final statics and therefore
    // mutable from anywhere; consider making them final once no external writer is confirmed.

    /** MinIO bucket holding trained model artifacts; used when deleting a model. */
    public static String MLMODEL_BUCKET = "ml-model";

    /** Maximum number of models a user may have running at once; enforced by train(). */
    public static long RUNNING_LIMIT = 5;

    // NOTE(review): not referenced in the visible part of this file. Presumably a row-count
    // threshold for choosing an execution engine; confirm its meaning with callers.
    public static long ENGINE_SWITCH = 7000000;

    /**
     * Case-insensitive substring test.
     *
     * @param name      the text to search in; must not be {@code null}
     * @param searchKey the substring to look for; {@code null} matches everything
     * @return {@code true} when {@code searchKey} is {@code null} or occurs in {@code name}
     *         ignoring case
     */
    private boolean isContainKey(String name, String searchKey) {
        if (searchKey == null) {
            return true;
        }
        // Locale.ROOT keeps case folding independent of the JVM default locale
        // (e.g. the Turkish dotted/dotless "i" would otherwise break matches).
        return name.toLowerCase(Locale.ROOT).contains(searchKey.toLowerCase(Locale.ROOT));
    }

    /**
     * Starts training for an existing model.
     *
     * <p>The request is rejected when any of the following holds:
     * <ul>
     *   <li>the id is missing or unknown;</li>
     *   <li>the current user already has {@link #RUNNING_LIMIT} or more running models;</li>
     *   <li>the model has no source table.</li>
     * </ul>
     * Otherwise the model is marked {@code RUNNING}, its run time is reset and persisted, and a
     * training job is submitted via {@code MLModelService#beginTraining}.
     *
     * @param request the HTTP request (unused)
     * @param model   the model to train; {@code id} is required, and the object is persisted with
     *                {@code mlModelService.update} after the status fields are set
     * @return the outcome code:
     *         <ul>
     *           <li>SUCCESS when the job was submitted;</li>
     *           <li>PARAM_ERROR, MODEL_RUNNING_LIMIT or MODEL_NO_SOURCE when validation fails;</li>
     *           <li>SYS_ERROR when submission failed.</li>
     *         </ul>
     */
    @PostMapping(value = "/train")
    @ResponseBody
    @ApiOperation(value = "执行模型训练", notes = "执行模型训练")
    //ApiResult<List<Long>>
    public ApiResult train(HttpServletRequest request, @RequestBody MLModelDTO model) {
        if (model.getId() == null) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }
        long userId = JwtUtil.getCurrentUserId();
        // TODO userId security check
        long numRunning = mlModelService.queryNumRunning(userId);
        if (numRunning >= RUNNING_LIMIT) {
            return ApiResult.valueOf(ApiResultCode.MODEL_RUNNING_LIMIT);
        }
        MLModelDTO modelDTO = mlModelService.queryMetricsById(model.getId());
        if (modelDTO == null) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }
        if (modelDTO.getSourceTable() == null) {
            return ApiResult.valueOf(ApiResultCode.MODEL_NO_SOURCE);
        }
        model.setUserId(userId);
        model.setStatus("RUNNING");
        model.setRunTime(0f);
        mlModelService.updateTrainTime(model.getId());
        mlModelService.update(model);

        String res = null;
        try {
            JSONObject params = new JSONObject();
            params.put("apiPath", "ml_model");
            params.put("execution", "train");
            params.put("model_id", model.getId());
            res = mlModelService.beginTraining(params);
        } catch (IOException e) {
            logger.error("Failed to submit training job for model {}", model.getId(), e);
        }
        // NOTE(review): on submission failure the model stays in RUNNING status and keeps
        // counting against RUNNING_LIMIT; consider resetting it once the failure status is agreed.
        if (res != null) {
            return ApiResult.valueOf(ApiResultCode.SUCCESS);
        }
        return ApiResult.valueOf(ApiResultCode.SYS_ERROR);
    }
//
//    @PostMapping(value = "/predict")
//    @ResponseBody
//    @ApiOperation(value = "执行模型预测", notes = "执行模型预测")
//    public ApiResult predict(HttpServletRequest request, @RequestBody JSONObject params) {
//        Long taskId = params.getLong("taskId");
//        TaskDTO task = taskService.queryById(taskId);
//        JSONObject dataJson = JSONObject.parseObject(task.getDataJson());
//        JSONArray outputs = dataJson.getJSONArray("output");
//        JSONObject output = outputs.getJSONObject(0);
//        Long modelId = dataJson.getLong("modelId");
//        String target = "pipeline." + "ml_" + task.getUserId() + "_";
//        long timeStamp = System.currentTimeMillis();
////        String lastTimeStamp = String.valueOf(timeStamp).substring(0,13);
//        target += timeStamp;
//        dagScheduler.setLastTimeStamp(task, timeStamp);
//        taskService.update(task);
//        TaskInstanceDTO taskInstance = new TaskInstanceDTO();
//        taskInstance.setTaskId(taskId);
//        taskInstance.setPipelineId(task.getPipelineId());
//        taskInstance.setProjectId(task.getProjectId());
//        taskInstance.setUserId(task.getUserId());
//        taskInstance.setParentId(task.getParentId());
//        JSONObject newParam = dataJson.getJSONObject("param");
//        newParam.put("execution", "predict");
//        String featureX = dataJson.getString("feature_X");
//        newParam.put("feature_col", featureX);
//        newParam.put("source", output.getString("tableName"));
//        newParam.put("model_saved_path", dataJson.getString("modelPath"));
//        newParam.put("model_id", dataJson.getLong("modelId"));
//        newParam.put("target", target);
//        taskInstance.setDataJson(newParam.toJSONString());
//        Long instanceId = taskInstanceService.save(taskInstance);
//
//        JSONArray inputs = dataJson.getJSONArray("input");
//        JSONObject input = inputs.getJSONObject(0);
//        String source = input.getString("tableName");
//
//        //String appArgs = params.getString("param");
//        String appArgs = "--id " + modelId + " --target " + target + " --exe predict"
//                + " --source " + source + " --feature_col " + featureX
//                + " --task_id " + taskId + " --instance_id " + instanceId;
////        appArgs += " --source " + source;
//        mlModelService.exec(appArgs);
//        return ApiResult.valueOf(ApiResultCode.SUCCESS);
//    }


    /**
     * Creates a new model in {@code PENDING} status with empty data and params.
     *
     * <p>The request is rejected when the project id is missing or non-positive, when the
     * project already holds the maximum number of models, or when another model in the
     * project has the same name.
     *
     * @param vo the model to create; {@code projectId} is required
     * @return the new model id, or PARAM_ERROR / MODEL_NUM_LIMIT / MODEL_NAME_DUP
     */
    @PostMapping(value = "/create")
    @ResponseBody
    @ApiOperation(value = "添加模型", notes = "添加新建模型")
    public ApiResult<Long> create(@RequestBody MLModelVO vo) {
        final int maxModelsPerProject = 30;
        long userId = JwtUtil.getCurrentUserId();
        vo.setUserId(userId);
        vo.setData(new JSONObject());
        vo.setParam(new JSONObject());
        Long projectId = vo.getProjectId();
        if (projectId == null || projectId <= 0L) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }
        List<MLModelDTO> modelDTOS = mlModelService.queryByProjectId(vo);
        if (modelDTOS.size() >= maxModelsPerProject) {
            return ApiResult.valueOf(ApiResultCode.MODEL_NUM_LIMIT);
        }
        String name = vo.getName();
        for (MLModelDTO existing : modelDTOS) {
            // Compare from the request side so a stored model without a name cannot NPE.
            if (name != null && name.equals(existing.getName())) {
                return ApiResult.valueOf(ApiResultCode.MODEL_NAME_DUP);
            }
        }
        vo.setStatus("PENDING");
        Long id = mlModelService.save(vo.toMLModel());

        return ApiResult.valueOf(id);
    }

    /**
     * Updates a model's configuration and resets it to {@code PENDING}.
     *
     * <ol>
     *   <li>Validates {@code projectId}/{@code id} and rejects a name already used by another
     *       model in the project.</li>
     *   <li>If the stored model is in the panel ({@code inPanel != 0}), saves a copy of it marked
     *       {@code DELETED_IN_FIELD}. Presumably this leaves a frozen snapshot in the panel. The
     *       stored model is then detached from the panel/folder and marked {@code NORMAL}.</li>
     *   <li>Fills {@code sourceId}/{@code sourceName}/{@code num} from the project dataset whose
     *       {@code "dataset." + table} equals the requested source table.</li>
     *   <li>Merges the request {@code param} into the stored model parameters and adds the
     *       training bookkeeping keys ({@code source}, {@code usr_id}, {@code model_id},
     *       {@code execution}, {@code algo}).</li>
     * </ol>
     *
     * @param vo the model changes; {@code projectId} and {@code id} are required
     * @return the model id, or PARAM_ERROR / MODEL_NAME_DUP on validation failure
     */
    @PostMapping(value = "/update")
    @ResponseBody
    @ApiOperation(value = "更新模型", notes = "更新模型信息")
    public ApiResult<Long> update(@RequestBody MLModelVO vo) {
        long userId = JwtUtil.getCurrentUserId();
        vo.setUserId(userId);
        vo.setData(new JSONObject());
        Long projectId = vo.getProjectId();
        if (projectId == null || projectId <= 0L || vo.getId() == null) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }
        String newName = vo.getName();
        if (newName != null) {
            List<MLModelDTO> modelDTOS = mlModelService.queryByProjectId(vo);
            for (MLModelDTO modelDTO : modelDTOS) {
                if (!vo.getId().equals(modelDTO.getId()) && newName.equals(modelDTO.getName())) {
                    return ApiResult.valueOf(ApiResultCode.MODEL_NAME_DUP);
                }
            }
        }
        MLModelDTO model = mlModelService.queryMetricsById(vo.getId());
        if (model == null) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }
        Long inPanel = model.getInPanel();
        if (inPanel != null && inPanel != 0L) {
            Long currentId = model.getId();
            // Save a copy of the panel entry and hide it from the field view.
            model.setId(null);
            Long newId = mlModelService.save(model);
            model.setId(newId);
            model.setInvisible(ModelStatusEnum.DELETED_IN_FIELD.getVal());
            mlModelService.update(model);

            // Detach the original model from the panel/folder and make it visible again.
            model.setId(currentId);
            model.setInPanel(0L);
            model.setFolderId(0L);
            model.setInvisible(ModelStatusEnum.NORMAL.getVal());
            mlModelService.update(model);
        }

        String sourceTable = vo.getSourceTable();
        List<ProjectDatasetDTO> pdds = datasetProjectMapper.queryProjectDataset(projectId, "");
        if (sourceTable != null && pdds != null) {
            for (ProjectDatasetDTO pdd : pdds) {
                JSONObject dataJson = JSONObject.parseObject(pdd.getDataJson());
                if (dataJson != null && dataJson.containsKey("table")
                        && sourceTable.equals("dataset." + dataJson.get("table"))) {
                    vo.setSourceId(pdd.getId());
                    vo.setSourceName(pdd.getDatasetName());
                    vo.setNum(dataJson.getLong("totalRow"));
                }
            }
        }

        // Start from the stored parameters and overlay the request's params.
        JSONObject param = JSONObject.parseObject(model.getModelParam());
        if (param == null) {
            param = new JSONObject();
        }
        JSONObject newParam = vo.getParam();
        if (newParam != null) {
            param.putAll(newParam);
        }
        if (sourceTable != null) {
            param.put("source", sourceTable);
        }
        param.put("usr_id", userId);
        param.put("model_id", vo.getId());
        param.put("execution", "train");
        if (vo.getAlgorithm() != null) {
            param.put("algo", vo.getAlgorithm());
        }
        vo.setParam(param);

        vo.setStatus("PENDING");
        MLModelDTO dto = vo.toMLModel();
        // Any previous training result is discarded.
        dto.setModelInfo("{}");
        dto.setRunTime(0f);
        mlModelService.update(dto);

        return ApiResult.valueOf(vo.getId());
    }

    /**
     * Deletes a model and, when applicable, its stored artifacts in MinIO.
     *
     * <ul>
     *   <li>A model whose {@code invisible} flag is {@code COPIED} only has its record deleted.</li>
     *   <li>A model referenced by the panel ({@code inPanel} 1 or 2) is only hidden.</li>
     *   <li>Otherwise the artifacts under {@code ml-model/} are removed. For spark models these
     *       are the {@code /data} and {@code /metadata} prefixes; otherwise a single object is
     *       removed. The record is deleted afterwards.</li>
     * </ul>
     *
     * @param vo the model to delete; {@code projectId} and {@code id} are required
     * @return the model id, SUCCESS when the model was only hidden, or PARAM_ERROR
     */
    @SneakyThrows
    @PostMapping(value = "/delete")
    @ResponseBody
    @ApiOperation(value = "删除模型", notes = "删除模型文件及模型信息")
    public ApiResult<Long> delete(@RequestBody MLModelVO vo) {
        if (vo.getUserId() == null) {
            vo.setUserId(JwtUtil.getCurrentUserId());
        }

        Long projectId = vo.getProjectId();
        if (projectId == null || projectId <= 0L || vo.getId() == null) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }

        // NOTE(review): there is no ownership/project check here (queryMetricsById checks the
        // owner); confirm whether any project member may delete any model.
        MLModelDTO dto = mlModelService.queryMetricsById(vo.getId());
        if (dto == null) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }
        if (!Objects.equals(dto.getInvisible(), ModelStatusEnum.COPIED.getVal())) {
            // If the model is in the panel, only mark it invisible.
            Long inPanel = dto.getInPanel();
            if (inPanel != null && (inPanel == 1L || inPanel == 2L)) {
                mlModelService.setInvisible(vo.getId());
                return ApiResult.valueOf(ApiResultCode.SUCCESS);
            }
            String minioModelPath = dto.getModelPath();
            if (minioModelPath != null) {
                String[] pathParts = minioModelPath.split("ml-model/");
                if (pathParts.length < 2) {
                    // Previously an ArrayIndexOutOfBoundsException; the record is still deleted.
                    logger.warn("Model {} has unexpected model path '{}', skipping MinIO cleanup",
                            vo.getId(), minioModelPath);
                } else if (isContainKey(minioModelPath, "spark")) {
                    String modelpath = pathParts[1];

                    String prefixData = modelpath + "/data";
                    String prefixMetaData = modelpath + "/metadata";
                    // NOTE(review): assumes listObjects returns names relative to the prefix.
                    List<String> dataobjectNames = minioService.listObjects(MLMODEL_BUCKET, prefixData);
                    List<String> metaobjectNames = minioService.listObjects(MLMODEL_BUCKET, prefixMetaData);

                    List<String> objectNames = new ArrayList<>();
                    for (String objectName : dataobjectNames) {
                        objectNames.add(prefixData + objectName);
                    }
                    for (String objectName : metaobjectNames) {
                        objectNames.add(prefixMetaData + objectName);
                    }
                    minioService.deleteObjects(MLMODEL_BUCKET, objectNames);
                } else {
                    minioService.deleteObject(MLMODEL_BUCKET, pathParts[1]);
                }
            }
        }
        mlModelService.delete(vo.getId());

        return ApiResult.valueOf(vo.getId());
    }

    /**
     * Returns a single model's evaluation report.
     *
     * <p>When the model has a source dataset, these keys are added to the model's {@code data}:
     * <ul>
     *   <li>{@code sourceSize}, the dataset's {@code size};</li>
     *   <li>{@code colNum}, the number of {@code head} entries;</li>
     *   <li>{@code rowNum}, the dataset's {@code count}.</li>
     * </ul>
     * Each value comes from {@code DatasetService#queryDataById}. The defaults ("", 0, 0) are
     * kept when that lookup fails or omits a value.
     *
     * @param request the HTTP request (unused)
     * @param vo      carries the model {@code id}
     * @return the model view, PARAM_ERROR for a missing/unknown id, or NOT_YOUR_MODEL when the
     *         model belongs to another user
     */
    @PostMapping(value = "/queryMetricsById")
    @ResponseBody
    @ApiOperation(value = "查询模型评估报告", notes = "根据id查询单个模型评估报告")
    public ApiResult<MLModelVO> queryMetricsById(HttpServletRequest request, @RequestBody @ProjectAuth(auth = ProjectAuthEnum.READ) MLModelVO vo) {
        if (vo.getId() == null || vo.getId() <= 0L) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }

        MLModelDTO model = mlModelService.queryMetricsById(vo.getId());
        if (model == null) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }
        long currentUserId = JwtUtil.getCurrentUserId();
        // Objects.equals compares by value whether getUserId() is boxed or not, and is null-safe.
        if (!Objects.equals(model.getUserId(), currentUserId)) {
            return ApiResult.valueOf(ApiResultCode.NOT_YOUR_MODEL);
        }
        Long sourceId = model.getSourceId();
        String size = "";
        int rowNum = 0;
        int colNum = 0;
        if (sourceId != null) {
            JSONObject sourceInfo = datasetService.queryDataById(sourceId, null, null);
            if (sourceInfo != null && Objects.equals(sourceInfo.getInteger("code"), 200)) {
                JSONObject sourceData = sourceInfo.getJSONObject("data");
                if (sourceData != null) {
                    size = sourceData.getString("size");
                    JSONArray head = sourceData.getJSONArray("head");
                    colNum = head == null ? 0 : head.size();
                    Integer count = sourceData.getInteger("count");
                    rowNum = count == null ? 0 : count;
                }
            }
        }
        MLModelVO ret = model.view();

        JSONObject data = ret.getData();
        if (data != null) {
            data.put("sourceSize", size);
            data.put("colNum", colNum);
            data.put("rowNum", rowNum);
            ret.setData(data);
        }

        return ApiResult.valueOf(ret);
    }

    /**
     * Returns a window of rows from a model's output table ({@code data.out_table}).
     *
     * <p>Rows are selected by {@code _record_id_}, whose row-id column is dropped from the
     * response. The window is chosen as follows:
     * <ul>
     *   <li>by default, {@code 0..1000000};</li>
     *   <li>{@code start}/{@code end} override the default when both are supplied;</li>
     *   <li>{@code page} selects rows {@code page*10-9 .. page*10} and overrides both.</li>
     * </ul>
     *
     * @param request the HTTP request (unused)
     * @param params  JSON with {@code id} (required) and optional {@code start}, {@code end},
     *                {@code page}
     * @return a JSON object with {@code head} (column names), {@code data} (rows) and
     *         {@code totalPages} (ceil(num / 10)). On failure: PARAM_ERROR, NOT_YOUR_MODEL
     *         or GET_OUT_FAIL.
     */
    @PostMapping(value = "/queryOutputById")
    @ResponseBody
    @ApiOperation(value = "查询模型输出结果", notes = "根据id查询单个模型输出结果")
    public ApiResult<JSONObject> queryOutputById(HttpServletRequest request, @RequestBody @ProjectAuth(auth = ProjectAuthEnum.READ) JSONObject params) {
        final long pageSize = 10L;
        Long modelId = params.getLong("id");
        if (modelId == null || modelId <= 0L) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }
        long start = 0L;
        long end = 1000000L;
        Long startParam = params.getLong("start");
        Long endParam = params.getLong("end");
        if (startParam != null && endParam != null) {
            start = startParam;
            end = endParam;
        }
        Long page = params.getLong("page");
        if (page != null) {
            end = pageSize * page;
            start = end - (pageSize - 1);
        }

        MLModelDTO model = mlModelService.queryMetricsById(modelId);
        if (model == null) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }
        long currentUserId = JwtUtil.getCurrentUserId();
        if (!Objects.equals(model.getUserId(), currentUserId)) {
            return ApiResult.valueOf(ApiResultCode.NOT_YOUR_MODEL);
        }
        JSONObject data = model.view().getData();
        String tableName = data == null ? null : data.getString("out_table");
        String[] tableParts = tableName == null ? new String[0] : tableName.split("\\.");
        if (tableParts.length < 2) {
            logger.error("Model {} has no usable output table: {}", modelId, tableName);
            return ApiResult.valueOf(ApiResultCode.GET_OUT_FAIL);
        }

        JSONObject output = new JSONObject();
        try {
            String schema = "ml_model";
            String table = tableParts[1];
            String sortSql = "";
            // start/end are numeric, so concatenation cannot inject SQL. The space before
            // "order by" is required.
            String limitSql = "where _record_id_ between " + start + " AND " + end + " order by _record_id_";
            String sql = String.format(DatabaseConstant.GP_SELECT_SQL, schema, SqlUtil.formatPGSqlColName(table), sortSql, limitSql);

            String dataSourceKey = servletContext.getAttribute(Constant.DEFAULT_DATA_SOURCE_KEY).toString();
            JSONArray queryResult = gpDataProvider.executeQuerySQL(dataSourceKey, sql, JSONArray.class);
            if (queryResult == null) {
                queryResult = new JSONArray();
            }

            // Build the head structure from the first row's columns, skipping the row-id column.
            JSONArray heads = new JSONArray();
            if (!queryResult.isEmpty()) {
                Map<String, Object> firstRow = queryResult.getJSONObject(0).getInnerMap();
                for (String colName : firstRow.keySet()) {
                    if (DatasetConstant.DEFAULT_ID_FIELD.equals(colName)) {
                        continue;
                    }
                    JSONObject head = new JSONObject();
                    head.put("name", colName);
                    heads.add(head);
                }
            }
            output.put("head", heads);

            // Build the data structure: the rows themselves, minus the row-id column.
            for (Object row : queryResult) {
                if (row instanceof Map) {
                    ((Map<?, ?>) row).remove(DatasetConstant.DEFAULT_ID_FIELD);
                }
            }
            output.put("data", queryResult);

            Long dataNum = model.getNum();
            long totalPages = dataNum == null ? 0L : (dataNum + pageSize - 1) / pageSize;
            output.put("totalPages", totalPages);
        } catch (Exception e) {
            logger.error("Failed to query output of model {}", modelId, e);
            return ApiResult.valueOf(ApiResultCode.GET_OUT_FAIL);
        }
        return ApiResult.valueOf(output);
    }


    /**
     * Lists the models of a project.
     *
     * <p>The filter depends on {@code status}:
     * <ul>
     *   <li>When it is null or empty, all models of the project are returned. Each model's
     *       {@code sourceId}/{@code sourceName} is re-synced from the project dataset matching its
     *       source table, accepting both {@code table} and {@code "dataset." + table} (the form
     *       written by update()). The model is persisted only when these values changed.</li>
     *   <li>Otherwise the models with that status are returned.</li>
     * </ul>
     * Models owned by the current user are moved to the front of the list.
     *
     * @param request the HTTP request (unused)
     * @param vo      carries {@code projectId} (required) and optional {@code status}
     * @return the model views, or PARAM_ERROR for a missing/non-positive project id
     */
    @PostMapping(value = "/queryByProjectId")
    @ResponseBody
    @ApiOperation(value = "查询模型列表", notes = "根据项目id查询所有模型")
    public ApiResult<List<MLModelVO>> queryByProjectId(HttpServletRequest request, @RequestBody @ProjectAuth(auth = ProjectAuthEnum.READ) MLModelVO vo) {
        if (vo.getProjectId() == null || vo.getProjectId() <= 0L) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }
        long userId = JwtUtil.getCurrentUserId();
        vo.setUserId(userId);
        List<MLModelVO> ret = new ArrayList<>();
        // isEmpty instead of `== ""`: an empty status deserialized from JSON is a new String.
        if (StringUtils.isEmpty(vo.getStatus())) {
            List<MLModelDTO> models = mlModelService.queryByProjectId(vo);
            List<ProjectDatasetDTO> pdds = datasetProjectMapper.queryProjectDataset(vo.getProjectId(), "");
            if (models != null) {
                for (MLModelDTO model : models) {
                    String sourceTable = model.getSourceTable();
                    if (sourceTable != null && pdds != null) {
                        for (ProjectDatasetDTO pdd : pdds) {
                            JSONObject dataJson = JSONObject.parseObject(pdd.getDataJson());
                            String table = dataJson == null ? null : dataJson.getString("table");
                            if (table == null) {
                                continue;
                            }
                            if (sourceTable.equals(table) || sourceTable.equals("dataset." + table)) {
                                boolean changed = !Objects.equals(model.getSourceId(), pdd.getId())
                                        || !Objects.equals(model.getSourceName(), pdd.getDatasetName());
                                model.setSourceId(pdd.getId());
                                model.setSourceName(pdd.getDatasetName());
                                if (changed) {
                                    mlModelService.update(model);
                                }
                            }
                        }
                    }
                    // Models owned by the current user go to the front of the list.
                    if (model.getUserId() == userId) {
                        ret.add(0, model.view());
                    } else {
                        ret.add(model.view());
                    }
                }
            }
        } else {
            List<MLModelDTO> models = mlModelService.queryByStatus(vo);
            if (models != null) {
                for (MLModelDTO model : models) {
                    if (model.getUserId() == userId) {
                        ret.add(0, model.view());
                    } else {
                        ret.add(model.view());
                    }
                }
            }
        }
        return ApiResult.valueOf(ret);
    }

    /**
     * Stops a model's training.
     *
     * <p>The model status is set to {@code KILLED} first, then
     * {@code MLModelService#killTraining} is invoked. SUCCESS is returned even when the kill
     * call reports failure; that case is only logged.
     *
     * @param request the HTTP request (unused)
     * @param model   carries the model {@code id}
     * @return SUCCESS, PARAM_ERROR for a missing/unknown id, or MODEL_NO_SOURCE
     */
    @PostMapping(value = "/killTraining")
    @ResponseBody
    @ApiOperation(value = "停止训练", notes = "根据application id杀掉spark进程")
    public ApiResult<Long> killTraining(HttpServletRequest request, @RequestBody @ProjectAuth(auth = ProjectAuthEnum.READ) MLModelVO model) {
        if (model.getId() == null) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }
        long userId = JwtUtil.getCurrentUserId();
        MLModelDTO modelDTO = mlModelService.queryMetricsById(model.getId());
        if (modelDTO == null) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }
        if (modelDTO.getSourceTable() == null) {
            return ApiResult.valueOf(ApiResultCode.MODEL_NO_SOURCE);
        }
        model.setUserId(userId);
        model.setStatus("KILLED");
        mlModelService.updateStatus(model);
        boolean killed = mlModelService.killTraining(model);
        if (!killed) {
            logger.warn("killTraining reported failure for model {}", model.getId());
        }

        return ApiResult.valueOf(ApiResultCode.SUCCESS);
    }

    /**
     * Returns the training job status of a model.
     *
     * @param request the HTTP request (unused)
     * @param params  JSON with the model {@code id} (required)
     * @return the job status, or PARAM_ERROR when {@code id} is missing
     */
    @PostMapping(value = "/queryTrainingStatus")
    @ResponseBody
    @ApiOperation(value = "查询训练状态", notes = "查询模型训练状态")
    public ApiResult<JobStatusVO> queryTrainingStatus(HttpServletRequest request, @RequestBody @ProjectAuth(auth = ProjectAuthEnum.READ) JSONObject params) {
        Long id = params.getLong("id");
        if (id == null) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }
        // The first argument is passed as an empty string; the lookup relies on the model id.
        JobStatusVO jsvo = mlModelService.queryJobstatus("", id);
        return ApiResult.valueOf(jsvo);
    }

    /**
     * Lists the models that match the requested training status.
     *
     * <p>The caller's user id is written into {@code vo} before querying. Models owned by the
     * caller are inserted at the head of the result; all others are appended in query order.
     *
     * @param request the HTTP request (unused)
     * @param vo      the filter passed to {@code MLModelService#queryByStatus}
     * @return the matching models as view objects (empty when nothing matches)
     */
    @PostMapping(value = "/queryByStatus")
    @ResponseBody
    @ApiOperation(value = "根据训练status查询模型列表", notes = "根据训练status查询所有模型")
    public ApiResult<List<MLModelVO>> queryByStatus(HttpServletRequest request, @RequestBody @ProjectAuth(auth = ProjectAuthEnum.READ) MLModelVO vo) {
        long currentUserId = JwtUtil.getCurrentUserId();
        vo.setUserId(currentUserId);
        List<MLModelDTO> matches = mlModelService.queryByStatus(vo);
        List<MLModelVO> views = new ArrayList<>();
        if (matches == null) {
            return ApiResult.valueOf(views);
        }
        for (MLModelDTO candidate : matches) {
            boolean ownedByCaller = candidate.getUserId() == currentUserId;
            int insertAt = ownedByCaller ? 0 : views.size();
            views.add(insertAt, candidate.view());
        }
        return ApiResult.valueOf(views);
    }

    /**
     * Exports a model to the panel, either at the top level or inside a folder.
     *
     * <ul>
     *   <li>Optionally renames the model, rejecting a name used by another panel model.</li>
     *   <li>Optionally updates the model description.</li>
     *   <li>Without {@code folderId}: {@code inPanel=1}, {@code folderId=0}.</li>
     *   <li>With {@code folderId}: {@code inPanel=2} and the folder id is stored. The folder
     *       must belong to the request's project.</li>
     * </ul>
     *
     * @param request the HTTP request (unused)
     * @param vo      carries {@code id} (required), optional {@code name}, {@code modelDesc},
     *                {@code folderId}, and {@code projectId}
     * @return SUCCESS, PARAM_ERROR, or MODEL_NAME_DUP
     */
    @PostMapping(value = "/exportToFolder")
    @ResponseBody
    @ApiOperation(value = "导出模型", notes = "导出模型到列表中")
    public ApiResult<Long> exportToFolder(HttpServletRequest request, @RequestBody @ProjectAuth(auth = ProjectAuthEnum.READ) MLModelVO vo) {
        if (vo.getId() == null || vo.getId() <= 0L) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }
        long userId = JwtUtil.getCurrentUserId();
        vo.setUserId(userId);
        MLModelDTO modelDTO = mlModelService.queryMetricsById(vo.getId());
        if (modelDTO == null) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }
        String modelName = vo.getName();
        Long folderId = vo.getFolderId();
        String modelDesc = vo.getModelDesc();
        if (StringUtils.isNotEmpty(modelName)) {
            List<MLModelDTO> modelDTOSinPanel = mlModelService.queryModelPanel(vo);
            for (MLModelDTO model : modelDTOSinPanel) {
                if (!vo.getId().equals(model.getId()) && modelName.equals(model.getName())) {
                    return ApiResult.valueOf(ApiResultCode.MODEL_NAME_DUP);
                }
            }
            modelDTO.setName(modelName);
        }
        if (StringUtils.isNotEmpty(modelDesc)) {
            modelDTO.setModelDesc(modelDesc);
        }
        if (folderId == null) {
            modelDTO.setInPanel(1L);
            modelDTO.setFolderId(0L);
            // TODO(original author): check whether the folder table needs to be updated.
        } else {
            // The target folder must belong to the same project as the request.
            FolderDTO folderDTO = mlModelService.getFolderById(vo);
            if (folderDTO == null || folderDTO.getProjectId() == null
                    || !folderDTO.getProjectId().equals(vo.getProjectId())) {
                return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
            }
            modelDTO.setInPanel(2L);
            modelDTO.setFolderId(folderId);
        }
        mlModelService.update(modelDTO);
        return ApiResult.valueOf(ApiResultCode.SUCCESS);
    }

    /**
     * Creates a model folder for the current user.
     *
     * <p>{@code FolderService#createFolder} signals failures with negative ids:
     * <ul>
     *   <li>-1: duplicate name;</li>
     *   <li>-2: folder count limit reached;</li>
     *   <li>-3: invalid name.</li>
     * </ul>
     *
     * @param vo the folder to create; {@code projectId} and {@code name} are required
     * @return the new folder id, or one of the following:
     *         <ul>
     *           <li>PARAM_ERROR;</li>
     *           <li>FOLDER_NAME_DUP, FOLDER_NUM_LIMIT or FOLDER_NAME_INVALID;</li>
     *           <li>SYS_ERROR when the service returns no id.</li>
     *         </ul>
     */
    @PostMapping(value = "/createFolder")
    @ResponseBody
    @ApiOperation(value = "添加文件夹", notes = "添加新建文件夹")
    public ApiResult<Long> createFolder(@RequestBody FolderVO vo) {
        long userId = JwtUtil.getCurrentUserId();
        vo.setUserId(userId);
        vo.setModelInfo(new JSONObject());
        Long projectId = vo.getProjectId();
        if (projectId == null || projectId <= 0L) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }

        if (vo.getName() == null) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }
        Long id = folderService.createFolder(vo.toFolder());
        if (id == null) {
            return ApiResult.valueOf(ApiResultCode.SYS_ERROR);
        }
        if (id == -1L) {
            return ApiResult.valueOf(ApiResultCode.FOLDER_NAME_DUP);
        }
        if (id == -2L) {
            return ApiResult.valueOf(ApiResultCode.FOLDER_NUM_LIMIT);
        }
        if (id == -3L) {
            return ApiResult.valueOf(ApiResultCode.FOLDER_NAME_INVALID);
        }
        return ApiResult.valueOf(id);
    }

    /**
     * Lists the models stored in a folder.
     *
     * <p>The caller's user id is written into {@code vo} before the lookup.
     *
     * @param request the HTTP request (unused)
     * @param vo      the folder filter; {@code name} is required
     * @return the folder's models as view objects, or PARAM_ERROR when {@code name} is missing
     */
    @PostMapping(value = "/queryByFolder")
    @ResponseBody
    @ApiOperation(value = "查询文件夹中的模型", notes = "根据文件夹查询模型")
    public ApiResult<List<MLModelVO>> queryByFolder(HttpServletRequest request, @RequestBody @ProjectAuth(auth = ProjectAuthEnum.READ) FolderVO vo) {
        boolean missingName = vo.getName() == null;
        if (missingName) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }
        long callerId = JwtUtil.getCurrentUserId();
        vo.setUserId(callerId);

        List<MLModelDTO> folderModels = mlModelService.queryFolder(vo);
        List<MLModelVO> views = new ArrayList<>();
        folderModels.forEach(folderModel -> views.add(folderModel.view()));
        return ApiResult.valueOf(views);
    }

    /**
     * Returns the column name to type-code mapping of a model's source table.
     *
     * @param request the HTTP request (unused)
     * @param vo      carries the model {@code id}
     * @return the column types from {@code GPDataProvider#getColumnTypesOriginal}. Otherwise
     *         PARAM_ERROR for a missing/unknown id, or MODEL_NO_SOURCE when no source table is set.
     */
    @PostMapping(value = "/queryColumns")
    @ResponseBody
    @ApiOperation(value = "根据数据表查询列名", notes = "根据数据表查询列名")
    public ApiResult<Map<String, Integer>> queryColumns(HttpServletRequest request, @RequestBody MLModelVO vo) {
        if (vo.getId() == null) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }
        // NOTE(review): unlike queryMetricsById there is no @ProjectAuth or owner check here.
        MLModelDTO model = mlModelService.queryMetricsById(vo.getId());
        if (model == null) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }
        String tableName = model.getSourceTable();
        if (tableName == null) {
            return ApiResult.valueOf(ApiResultCode.MODEL_NO_SOURCE);
        }
        Map<String, Integer> columnTypes = gpDataProvider
                .getColumnTypesOriginal(new Table(DEFAULT_DATASET_ID, tableName));

//        String defaultDataSourceKey = servletContext.getAttribute(Constant.DEFAULT_DATA_SOURCE_KEY).toString();
//        DataPattern dataPattern = new org.zjvis.datacenter.service.vo.table.Table(defaultDataSourceKey, tableName.split("\\.")[0], tableName.split("\\.")[1]);
//        List<Column> columnsInfo = gpDataService.queryTableMeta(dataPattern).getResult();
//        Map<String, Integer> columnTypes = Maps.newLinkedHashMap();
//        for (Column col: columnsInfo){
//            columnTypes.put(col.getLabel(), col.getType());
//        }
        return ApiResult.valueOf(columnTypes);
    }

    /**
     * Checks whether a table was produced by one of a pipeline's tasks.
     *
     * <p>The requested {@code tableName} is reduced to its second dot-separated segment. Each
     * task's first output table name is normalized:
     * <ul>
     *   <li>a trailing {@code '_'} gets the task's {@code lastTimeStamp} appended;</li>
     *   <li>a schema prefix is stripped.</li>
     * </ul>
     * The normalized name is then tested for containment in the requested name.
     *
     * @param request the HTTP request (unused)
     * @param param   JSON with {@code pipelineId} and a schema-qualified {@code tableName}
     * @return 1 when a matching task output exists, 0 otherwise. Malformed or missing input
     *         also returns 0.
     */
    @PostMapping(value = "/userTableCheck")
    @ResponseBody
    @ApiOperation(value = "核验数据集所属", notes = "核验数据集所属用户")
    public Integer userTableCheck(HttpServletRequest request, @RequestBody JSONObject param) {
        Long pipelineId = param.getLong("pipelineId");
        String qualifiedName = param.getString("tableName");
        String[] nameParts = qualifiedName == null ? new String[0] : qualifiedName.split("\\.");
        if (pipelineId == null || nameParts.length < 2 || nameParts[1].isEmpty()) {
            return 0;
        }
        String tableName = nameParts[1];
        List<TaskDTO> tasks = taskService.queryByPipeline(pipelineId);
        if (tasks == null) {
            return 0;
        }
        for (TaskDTO task : tasks) {
            JSONObject dataJson = JSONObject.parseObject(task.getDataJson());
            JSONArray outputs = dataJson == null ? null : dataJson.getJSONArray("output");
            if (outputs == null || outputs.isEmpty()) {
                continue;
            }
            JSONObject firstOutput = outputs.getJSONObject(0);
            String table = firstOutput == null ? null : firstOutput.getString("tableName");
            if (StringUtils.isEmpty(table)) {
                continue;
            }
            if (table.endsWith("_")) {
                Long parentTimeStamp = dataJson.getLong("lastTimeStamp");
                table = table + parentTimeStamp;
            }
            if (table.contains(".")) {
                String[] tableParts = table.split("\\.");
                table = tableParts.length > 1 ? tableParts[1] : "";
            }
            // An empty name would match any request ("x".contains("") is true), so skip it.
            // NOTE(review): containment also lets a short task table name match longer names.
            if (!table.isEmpty() && tableName.contains(table)) {
                return 1;
            }
        }
        return 0;
    }

    /**
     * Deletes a folder owned by the current user.
     *
     * @param request the HTTP request (unused)
     * @param param   JSON with {@code id} (folder), {@code projectId}, and {@code pipelineId}
     *                (forwarded to {@code FolderService#deleteFolder})
     * @return SUCCESS when the service returns 0, FOLDER_IN_USE for any other result, or
     *         PARAM_ERROR when {@code id}/{@code projectId} is missing
     */
    @PostMapping(value = "/deleteFolder")
    @ResponseBody
    @ApiOperation(value = "删除文件夹", notes = "根据id删除文件夹")
    public ApiResult<Void> deleteFolder(HttpServletRequest request,
                                        @RequestBody @ProjectAuth(auth = ProjectAuthEnum.READ) JSONObject param) {
        long userId = JwtUtil.getCurrentUserId();
        Long pipelineId = param.getLong("pipelineId");
        Long folderId = param.getLong("id");
        Long projectId = param.getLong("projectId");
        if (folderId == null || projectId == null) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }
        // NOTE(review): a destructive endpoint guarded only by READ project auth; confirm intent.
        FolderDTO folder = new FolderDTO();
        folder.setId(folderId);
        folder.setUserId(userId);
        folder.setProjectId(projectId);
        Long delRes = folderService.deleteFolder(folder, pipelineId);
        if (delRes != 0L) {
            return ApiResult.valueOf(ApiResultCode.FOLDER_IN_USE);
        }
        return ApiResult.valueOf(ApiResultCode.SUCCESS);
    }

    /**
     * Removes a model from its folder.
     *
     * <p>{@code MLModelService#modelDelAuth} is consulted first. Any non-zero result means the
     * model is in use and nothing is removed.
     *
     * @param request the HTTP request (unused)
     * @param param   JSON with {@code id} (model), {@code projectId} and {@code pipelineId}
     * @return SUCCESS, MODEL_IN_USE, or PARAM_ERROR when {@code id}/{@code projectId} is missing
     */
    @PostMapping(value = "/deleteModelInFolder")
    @ResponseBody
    @ApiOperation(value = "删除文件夹内的模型", notes = "删除文件夹内的模型")
    public ApiResult<Void> deleteModelInFolder(HttpServletRequest request,
                                               @RequestBody @ProjectAuth(auth = ProjectAuthEnum.READ) JSONObject param) {
        long userId = JwtUtil.getCurrentUserId();
        Long pipelineId = param.getLong("pipelineId");
        Long modelId = param.getLong("id");
        Long projectId = param.getLong("projectId");
        if (modelId == null || projectId == null) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }
        Long authRes = mlModelService.modelDelAuth(modelId, pipelineId);
        if (authRes != 0L) {
            return ApiResult.valueOf(ApiResultCode.MODEL_IN_USE);
        }
        MLModelDTO model = new MLModelDTO();
        model.setId(modelId);
        model.setProjectId(projectId);
        model.setUserId(userId);
        folderService.deleteModelInFolder(model);
        return ApiResult.valueOf(ApiResultCode.SUCCESS);
    }

    /**
     * Test endpoint: forwards the request body to {@code RestTemplateUtil#submitFlaskJob}.
     *
     * <p>NOTE(review): there is no project/auth check here; consider removing this endpoint or
     * restricting it outside development.
     *
     * @param request the HTTP request (unused)
     * @param param   the job payload, forwarded unchanged
     * @return the Flask response body, or "failed" when submission throws an IOException
     */
    @PostMapping(value = "/flaskTest")
    @ResponseBody
    @ApiOperation(value = "测试Flask任务提交", notes = "提交Flask任务，用于联调测试")
    public ApiResult<String> flaskTest(HttpServletRequest request, @RequestBody JSONObject param) {
        String body = "failed";
        try {
            body = restTemplateUtil.submitFlaskJob(param);
        } catch (IOException e) {
            logger.error("Flask test job submission failed", e);
        }
        return ApiResult.valueOf(body);
    }

    /**
     * Saves a model's output table as a dataset of the current project.
     *
     * <p>Builds a {@code TaskSaveVO} whose ML table name is {@code "output_" + modelId} and whose
     * task id is 0. It is then passed to {@code taskService.saveToDataSource}. A category id or a
     * category name is required, together with a non-blank dataset name.
     *
     * @param request current HTTP request (unused)
     * @param params  JSON body carrying {@code modelId}, {@code projectId}, {@code datasetName},
     *                {@code categoryId} and {@code categoryName}
     * @return the service's result, or {@code PARAM_ERROR} when {@code modelId} is missing or not positive
     * @throws DataScienceException with {@code TASK_SAVE_TO_DATASET_PARAM} when the category or
     *                              dataset name is missing
     */
    @PostMapping(value = "/saveOutputToDataSource")
    @ResponseBody
    @ApiOperation(value = "保存模型输出结果到当前项目数据集", notes = "保存模型输出结果到当前项目数据集")
    public ApiResult<Boolean> saveOutputToDataSource(HttpServletRequest request, @RequestBody @ProjectAuth(auth = ProjectAuthEnum.READ) JSONObject params) {
        Long modelId = params.getLong("modelId");
        Long projectId = params.getLong("projectId");
        String datasetName = params.getString("datasetName");
        Long categoryId = params.getLong("categoryId");
        String categoryName = params.getString("categoryName");
        if (modelId == null || modelId <= 0L) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }
        TaskSaveVO vo = new TaskSaveVO();
        vo.setProjectId(projectId);
        vo.setCategoryName(categoryName);
        // No pipeline task is involved here; presumably 0 tells TaskService to use tableNameML — confirm.
        vo.setTaskId(0L);
        vo.setTableNameML("output_" + modelId);
        vo.setCategoryId(categoryId);
        vo.setDatasetName(datasetName);
        // A category (by id or by name) and a dataset name are both required.
        if ((vo.getCategoryId() == null && StringUtils.isBlank(vo.getCategoryName())) ||
                StringUtils.isBlank(vo.getDatasetName())) {
            logger.error("API /saveOutputToDataSource failed, since {}",
                    BaseErrorCode.TASK_SAVE_TO_DATASET_PARAM.getMsg());
            throw new DataScienceException(BaseErrorCode.TASK_SAVE_TO_DATASET_PARAM);
        }
        Boolean ret = taskService.saveToDataSource(vo);
        return ApiResult.valueOf(ret);
    }

    /**
     * Saves a model's output table into data management as a dataset.
     *
     * <p>Builds a {@code TaskSaveVO} whose ML table name is {@code "output_" + id} and whose task id
     * is 0. It is then passed to {@code taskService.saveToDataset}. Unlike
     * {@code saveOutputToDataSource}, the model id is read from the {@code id} key, and
     * {@code categoryName} is not read from the request.
     *
     * @param req    current HTTP request (unused)
     * @param params JSON body carrying {@code id} (model id), {@code projectId}, {@code datasetName}
     *               and {@code categoryId}
     * @return the id returned by the service, or {@code PARAM_ERROR} when {@code id} is missing or not positive
     * @throws DataScienceException with {@code TASK_SAVE_TO_DATASET_PARAM} when the category or
     *                              dataset name is missing
     */
    @PostMapping(value = "/saveOutputToDataset")
    @ResponseBody
    @ApiOperation(value = "保存模型输出结果到数据管理", notes = "保存模型输出结果到数据管理")
    public ApiResult<Long> saveOutputToDataset(HttpServletRequest req, @RequestBody @ProjectAuth(auth = ProjectAuthEnum.READ) JSONObject params) {
        Long modelId = params.getLong("id");
        Long projectId = params.getLong("projectId");
        String datasetName = params.getString("datasetName");
        Long categoryId = params.getLong("categoryId");
        if (modelId == null || modelId <= 0L) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }
        TaskSaveVO vo = new TaskSaveVO();
        vo.setProjectId(projectId);
        // No pipeline task is involved here; presumably 0 tells TaskService to use tableNameML — confirm.
        vo.setTaskId(0L);
        vo.setTableNameML("output_" + modelId);
        vo.setCategoryId(categoryId);
        vo.setDatasetName(datasetName);
        // categoryName is never set in this method, so unless TaskSaveVO supplies a default,
        // categoryId is effectively required here.
        if ((vo.getCategoryId() == null && StringUtils.isBlank(vo.getCategoryName())) ||
                StringUtils.isBlank(vo.getDatasetName())) {
            logger.error("API /saveOutputToDataset failed, since {}",
                    BaseErrorCode.TASK_SAVE_TO_DATASET_PARAM.getMsg());
            throw new DataScienceException(BaseErrorCode.TASK_SAVE_TO_DATASET_PARAM);
        }
        Long ret = taskService.saveToDataset(vo);
        return ApiResult.valueOf(ret);
    }

    /**
     * Streams a model's output table to the client as a download.
     *
     * <p>The table name is {@code "ml_model.output_" + modelId}. The export itself, including what
     * is written to {@code response}, is delegated to {@code taskService.downloadData} with the
     * sheet name {@code "model_output"}.
     *
     * <p>NOTE(review): unlike sibling endpoints there is no {@code @ProjectAuth} check and no
     * {@code modelId > 0} validation here — confirm that access control is enforced elsewhere.
     *
     * @param request  current HTTP request (unused)
     * @param response response the export is written to
     * @param modelId  id of the model whose output table is exported
     */
    @RequestMapping(value = "/downloadOutput")
    @ResponseBody
    @ApiOperation(value = "下载模型输出结果为csv", notes = "下载模型输出结果为csv")
    public void download(HttpServletRequest request, HttpServletResponse response,
                         @RequestParam("modelId") Long modelId) {
        TaskSaveVO outputRef = new TaskSaveVO();
        outputRef.setTaskId(0L);
        outputRef.setTableNameML("ml_model.output_" + modelId);

        List<String> exportTables = new ArrayList<>(1);
        exportTables.add(outputRef.getTableNameML());
        taskService.downloadData(response, exportTables, "model_output");
    }



    /**
     * Searches the current user's models in a project by name and/or algorithm.
     *
     * <p>At least one of {@code searchKey} or {@code modelType} must be present. A {@code modelType}
     * of {@code "ALL"} disables the algorithm filter. The method loads the project's models through
     * {@code queryByProjectId} (filtered by {@code status}) and then keeps, in memory, those for
     * which {@code isContainKey} accepts both the name and the algorithm.
     *
     * @param request current HTTP request, forwarded to {@code queryByProjectId}
     * @param params  JSON body carrying {@code projectId}, {@code searchKey}, {@code modelType} and {@code status}
     * @return the matching models, or {@code PARAM_ERROR} on invalid input
     */
    @PostMapping(value = "/search")
    @ResponseBody
    @ApiOperation(value = "查找模型", notes = "查找模型")
    public ApiResult<List<MLModelVO>> search(HttpServletRequest request, @RequestBody @ProjectAuth(auth = ProjectAuthEnum.READ) JSONObject params) {
        final Long projectId = params.getLong("projectId");
        final String searchKey = params.getString("searchKey");
        final String requestedType = params.getString("modelType");
        final String status = params.getString("status");

        boolean invalidProject = projectId == null || projectId <= 0L;
        boolean noCriteria = requestedType == null && searchKey == null;
        if (invalidProject || noCriteria) {
            return ApiResult.valueOf(ApiResultCode.PARAM_ERROR);
        }
        // "ALL" means: do not filter by algorithm.
        final String algorithmKey = "ALL".equals(requestedType) ? null : requestedType;

        final long currentUser = JwtUtil.getCurrentUserId();
        MLModelVO query = new MLModelVO();
        query.setStatus(status);
        query.setUserId(currentUser);
        query.setProjectId(projectId);

        List<MLModelVO> candidates = this.queryByProjectId(request, query).getResult();
        List<MLModelVO> matches = new ArrayList<>();
        for (MLModelVO candidate : candidates) {
            boolean nameMatches = isContainKey(candidate.getName(), searchKey);
            if (nameMatches && isContainKey(candidate.getAlgorithm(), algorithmKey)) {
                matches.add(candidate);
            }
        }
        return ApiResult.valueOf(matches);
    }
}